{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s8Qn4neXQY4o"
      },
      "source": [
        "# SpICE Classifier model (Pytorch Usage)\n",
        "Licensed under the Apache License, Version 2.0\n",
        "\n",
        "Paper: Speech Intelligibility Classifiers from Half-a-Million Utterances\n",
        "\n",
        "This colab walks through how to download and use the SpICE wav2vec2 based speech intelligibility classifier. This colab walks you through how to use the model on a sample audio file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VJ4PjhssKrlv"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "\n",
        "import os\n",
        "import numpy as np\n",
        "import pickle\n",
        "import tensorflow as tf\n",
        "\n",
        "import IPython\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import requests\n",
        "import torch\n",
        "import torchaudio\n",
        "\n",
        "from google.colab import drive\n",
        "\n",
        "torch.random.manual_seed(0)\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "print(torch.__version__)\n",
        "print(torchaudio.__version__)\n",
        "print(device)\n",
        "\n",
        "SPEECH_URL = \"https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav\"  # noqa: E501\n",
        "SPEECH_FILE = \"_assets/speech.wav\"\n",
        "EXP_SAMPLE_RATE = 16000\n",
        "\n",
        "if not os.path.exists(SPEECH_FILE):\n",
        "    os.makedirs(\"_assets\", exist_ok=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x5rexoNRXJ-G"
      },
      "outputs": [],
      "source": [
        "!wget \"https${SPEECH_URL}\"\n",
        "!mv Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav _assets/speech.wav"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zuWBcwb_Y0LC"
      },
      "outputs": [],
      "source": [
        "#@title Give permissions to access Google Drive\n",
        "drive.mount('/content/gdrive')\n",
        "MODEL_HOME = \"/content/gdrive/MyDrive/euphonia/spice-w2v2-models/\" #@param"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ABF8lAneZEcs"
      },
      "outputs": [],
      "source": [
        "spice_w2v2_cls_model = torch.jit.load(f'{MODEL_HOME}/SpICE_w2v2_cls_scripted.pt')\n",
        "spice_w2v2_cls_model.eval()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9W3V7L5zCZfg"
      },
      "outputs": [],
      "source": [
        "def get_waveform(fpath):\n",
        "  waveform, sample_rate = torchaudio.load(fpath)\n",
        "  waveform = waveform.to(device)\n",
        "\n",
        "  if sample_rate != EXP_SAMPLE_RATE:\n",
        "    waveform = torchaudio.functional.resample(waveform, sample_rate, EXP_SAMPLE_RATE)\n",
        "  return waveform\n",
        "\n",
        "def get_prediction(fpath):\n",
        "  waveform = get_waveform(fpath)\n",
        "  with torch.inference_mode():\n",
        "    output = spice_w2v2_cls_model(waveform)\n",
        "    return output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IGSgalAdZa03"
      },
      "outputs": [],
      "source": [
        "get_prediction(SPEECH_FILE)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "19OL5PBwrXiPx-qR3DkJAxhjTLV_JcDyT",
          "timestamp": 1674515436547
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
